---
title: "Text-to-image"
description: "Building a text-to-image model with SDXL"
---


        <Card
          title="View on Github"
          icon="github" href="https://github.com/basetenlabs/truss-examples/tree/main/04-image-generation">
        </Card>

In this example, we go through a Truss that serves a text-to-image model. We
use SDXL 1.0, which is one of the highest performing text-to-image models out
there today.

# Set up imports and torch settings

In this example, we use the Huggingface diffusers library to build our text-to-image model.

```python model/model.py
import base64
import time
from io import BytesIO
from typing import Any

import torch
from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image

```
The following line is needed to enable TF32 on NVIDIA GPUs

```python model/model.py
torch.backends.cuda.matmul.allow_tf32 = True

```
# Define the `Model` class and load function

In the `load` function of the Truss, we implement logic involved in
downloading and setting up the model. For this model, we use the
`DiffusionPipeline` class in `diffusers` to instantiate our SDXL pipeline,
and configure a number of relevant parameters.

See the [diffusers docs](https://huggingface.co/docs/diffusers/index) for details
on all of these parameters.

```python model/model.py
class Model:
    def __init__(self, **kwargs):
        self._model = None

    def load(self):
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
        )
        self.pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )

        self.pipe.unet.to(memory_format=torch.channels_last)
        self.pipe.to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()

        self.refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.pipe.text_encoder_2,
            vae=self.pipe.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        self.refiner.to("cuda")
        self.refiner.enable_xformers_memory_efficient_attention()

```
This is a utility function for converting PIL image to base64.

```python model/model.py
    def convert_to_b64(self, image: Image) -> str:
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return img_b64

```
# Define the predict function

The `predict` function contains the actual inference logic. The steps here are:
  * Setting up the generation params. We have defaults for these, and some, such
as the `scheduler`, are somewhat complicated
  * Running the Diffusion Pipeline
  * If `use_refiner` is set to `True`, we run the refiner model on the output
  * Convert the resulting image to base64 and return it

```python model/model.py
    def predict(self, model_input: Any) -> Any:
        prompt = model_input.pop("prompt")
        negative_prompt = model_input.pop("negative_prompt", None)
        use_refiner = model_input.pop("use_refiner", True)
        num_inference_steps = model_input.pop("num_inference_steps", 30)
        denoising_frac = model_input.pop("denoising_frac", 0.8)
        end_cfg_frac = model_input.pop("end_cfg_frac", 0.4)
        guidance_scale = model_input.pop("guidance_scale", 7.5)
        seed = model_input.pop("seed", None)

        scheduler = model_input.pop(
            "scheduler", None
        )  # Default: EulerDiscreteScheduler (works pretty well)

```
Set the scheduler based on the user's input.
See possible schedulers: https://huggingface.co/docs/diffusers/api/schedulers/overview for
what the tradeoffs are.

```python model/model.py
        if scheduler == "DPM++ 2M":
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config
            )
        elif scheduler == "DPM++ 2M Karras":
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config, use_karras_sigmas=True
            )
        elif scheduler == "DPM++ 2M SDE Karras":
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config,
                algorithm_type="sde-dpmsolver++",
                use_karras_sigmas=True,
            )

        generator = None
        if seed is not None:
            torch.manual_seed(seed)
            generator = [torch.Generator(device="cuda").manual_seed(seed)]

        if not use_refiner:
            denoising_frac = 1.0

        start_time = time.time()
        image = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            generator=generator,
            end_cfg=end_cfg_frac,
            num_inference_steps=num_inference_steps,
            denoising_end=denoising_frac,
            guidance_scale=guidance_scale,
            output_type="latent" if use_refiner else "pil",
        ).images[0]
        scheduler = self.pipe.scheduler
        if use_refiner:
            self.refiner.scheduler = scheduler
            image = self.refiner(
                prompt=prompt,
                negative_prompt=negative_prompt,
                generator=generator,
                end_cfg=end_cfg_frac,
                num_inference_steps=num_inference_steps,
                denoising_start=denoising_frac,
                guidance_scale=guidance_scale,
                image=image[None, :],
            ).images[0]

```
Convert the results to base64, and return them.

```python model/model.py
        b64_results = self.convert_to_b64(image)
        end_time = time.time() - start_time

        print(f"Time: {end_time:.2f} seconds")

        return {"status": "success", "data": b64_results, "time": end_time}
```

# Setting up the config yaml

Running SDXL requires a handful of Python libraries, including
diffusers, transformers, and others.

```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata:
  example_model_input: {"prompt": "A tree in a field under the night sky", "use_refiner": true}
model_name: Stable Diffusion XL
python_version: py39
requirements:
- transformers==4.34.0
- accelerate==0.23.0
- safetensors==0.4.0
- git+https://github.com/basetenlabs/diffusers.git@9a353290b1497023d4745a719ec02c50f680499a
- invisible-watermark>=0.2.0
- xformers==0.0.22
```
## Configuring resources for SDXL 1.0

Note that we need an A10G to run this model.

```yaml config.yaml
resources:
  accelerator: A10G
  cpu: 3500m
  memory: 20Gi
  use_gpu: true
secrets: {}
```
## System Packages

Running diffusers requires `ffmpeg` and a couple other system
packages.

```yaml config.yaml
system_packages:
- ffmpeg
- libsm6
- libxext6
```
## Enabling Caching

SDXL is a very large model, and downloading it could take up to 10 minutes. This means
that the cold start time for this model is long. We can solve that by using our build
caching feature. This moves the model download to the build stage of your model--
caching the model will take about 10 minutes initially but you will get ~9s cold starts
subsequently.

To enable caching, add the following to the config:
```yaml
model_cache:
  - repo_id: madebyollin/sdxl-vae-fp16-fix
    allow_patterns:
      - config.json
      - diffusion_pytorch_model.safetensors
  - repo_id: stabilityai/stable-diffusion-xl-base-1.0
    allow_patterns:
      - "*.json"
      - "*.fp16.safetensors"
      - sd_xl_base_1.0.safetensors
  - repo_id: stabilityai/stable-diffusion-xl-refiner-1.0
    allow_patterns:
      - "*.json"
      - "*.fp16.safetensors"
      - sd_xl_refiner_1.0.safetensors
```
# Deploy the model

Deploy the model like you would other Trusses, with:
```bash
$ truss push
```
You can then invoke the model with:
```bash
$ truss predict -d '{"prompt": "A tree in a field under the night sky", "use_refiner": true}'
```

<RequestExample>
```python model/model.py
import base64
import time
from io import BytesIO
from typing import Any

import torch
from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image

torch.backends.cuda.matmul.allow_tf32 = True

class Model:
    def __init__(self, **kwargs):
        self._model = None

    def load(self):
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
        )
        self.pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )

        self.pipe.unet.to(memory_format=torch.channels_last)
        self.pipe.to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()

        self.refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.pipe.text_encoder_2,
            vae=self.pipe.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        self.refiner.to("cuda")
        self.refiner.enable_xformers_memory_efficient_attention()

    def convert_to_b64(self, image: Image) -> str:
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return img_b64

    def predict(self, model_input: Any) -> Any:
        prompt = model_input.pop("prompt")
        negative_prompt = model_input.pop("negative_prompt", None)
        use_refiner = model_input.pop("use_refiner", True)
        num_inference_steps = model_input.pop("num_inference_steps", 30)
        denoising_frac = model_input.pop("denoising_frac", 0.8)
        end_cfg_frac = model_input.pop("end_cfg_frac", 0.4)
        guidance_scale = model_input.pop("guidance_scale", 7.5)
        seed = model_input.pop("seed", None)

        scheduler = model_input.pop(
            "scheduler", None
        )  # Default: EulerDiscreteScheduler (works pretty well)

        if scheduler == "DPM++ 2M":
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config
            )
        elif scheduler == "DPM++ 2M Karras":
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config, use_karras_sigmas=True
            )
        elif scheduler == "DPM++ 2M SDE Karras":
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config,
                algorithm_type="sde-dpmsolver++",
                use_karras_sigmas=True,
            )

        generator = None
        if seed is not None:
            torch.manual_seed(seed)
            generator = [torch.Generator(device="cuda").manual_seed(seed)]

        if not use_refiner:
            denoising_frac = 1.0

        start_time = time.time()
        image = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            generator=generator,
            end_cfg=end_cfg_frac,
            num_inference_steps=num_inference_steps,
            denoising_end=denoising_frac,
            guidance_scale=guidance_scale,
            output_type="latent" if use_refiner else "pil",
        ).images[0]
        scheduler = self.pipe.scheduler
        if use_refiner:
            self.refiner.scheduler = scheduler
            image = self.refiner(
                prompt=prompt,
                negative_prompt=negative_prompt,
                generator=generator,
                end_cfg=end_cfg_frac,
                num_inference_steps=num_inference_steps,
                denoising_start=denoising_frac,
                guidance_scale=guidance_scale,
                image=image[None, :],
            ).images[0]

        b64_results = self.convert_to_b64(image)
        end_time = time.time() - start_time

        print(f"Time: {end_time:.2f} seconds")

        return {"status": "success", "data": b64_results, "time": end_time}
```
```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata:
  example_model_input: {"prompt": "A tree in a field under the night sky", "use_refiner": true}
model_name: Stable Diffusion XL
python_version: py39
requirements:
- transformers==4.34.0
- accelerate==0.23.0
- safetensors==0.4.0
- git+https://github.com/basetenlabs/diffusers.git@9a353290b1497023d4745a719ec02c50f680499a
- invisible-watermark>=0.2.0
- xformers==0.0.22
resources:
  accelerator: A10G
  cpu: 3500m
  memory: 20Gi
  use_gpu: true
secrets: {}
system_packages:
- ffmpeg
- libsm6
- libxext6
```
</RequestExample>
